{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|default_exp models.TSPerceiver"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# TSPerceiver"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This implementation is inspired by:\n",
    "\n",
    "Jaegle, A., Gimeno, F., Brock, A., Zisserman, A., Vinyals, O., & Carreira, J. (2021). \n",
    "\n",
    "<span style=\"color:dodgerblue\">**Perceiver: General Perception with Iterative Attention**</span>. arXiv preprint arXiv:2103.03206.\n",
    "\n",
    "Paper: https://arxiv.org/pdf/2103.03206.pdf\n",
    "\n",
    "Official repo: Not available as og April, 2021."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "from tsai.imports import *\n",
    "from tsai.models.layers import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|exporti\n",
    "class ScaledDotProductAttention(Module):\n",
    "    def __init__(self, d_k:int, res_attention:bool=False): \n",
    "        self.d_k,self.res_attention = d_k,res_attention\n",
    "\n",
    "    def forward(self, q:Tensor, k:Tensor, v:Tensor, prev:Optional[Tensor]=None, key_padding_mask:Optional[Tensor]=None, attn_mask:Optional[Tensor]=None):\n",
    "        '''\n",
    "        Input shape:\n",
    "            q               : [bs x n_heads x q_len x d_k]\n",
    "            k               : [bs x n_heads x d_k x seq_len]\n",
    "            v               : [bs x n_heads x seq_len x d_k]\n",
    "            key_padding_mask: [bs x seq_len]\n",
    "            attn_mask       : [seq_len x seq_len]\n",
    "\n",
    "        Output shape: \n",
    "            context: [bs x n_heads x q_len x d_v]\n",
    "            attn   : [bs x n_heads x q_len x seq_len]\n",
    "        '''\n",
    "\n",
    "        # MatMul (q, k) - similarity scores for all pairs of positions in an input sequence\n",
    "        scores = torch.matmul(q, k)                                   # scores : [bs x n_heads x q_len x seq_len]\n",
    "\n",
    "        # Scale\n",
    "        scores = scores / (self.d_k ** 0.5)\n",
    "\n",
    "        # Add previous scores (optional)\n",
    "        if prev is not None: scores = scores + prev\n",
    "\n",
    "        # Attention mask (optional)\n",
    "        if attn_mask is not None:                                     # attn_mask with shape [q_len x seq_len] - only used when q_len == seq_len\n",
    "            if attn_mask.dtype == torch.bool:\n",
    "                scores.masked_fill_(attn_mask, float('-inf'))\n",
    "            else:\n",
    "                scores += attn_mask\n",
    "\n",
    "        # Key padding mask (optional)\n",
    "        if key_padding_mask is not None:                              # key_padding_mask with shape [bs x seq_len]\n",
    "            scores.masked_fill_(key_padding_mask.unsqueeze(1).unsqueeze(2), float('-inf'))\n",
    "\n",
    "        # SoftMax\n",
    "        attn = F.softmax(scores, dim=-1)                               # attn   : [bs x n_heads x q_len x seq_len]\n",
    "\n",
    "        # MatMul (attn, v)\n",
    "        context = torch.matmul(attn, v)                                # context: [bs x n_heads x q_len x d_v]\n",
    "\n",
    "        if self.res_attention: return context, attn, scores\n",
    "        else: return context, attn\n",
    "\n",
    "\n",
    "class Attention(Module):\n",
    "    def __init__(self, d_latent:int, d_context:Optional[int]=None, n_heads:int=8, d_head:Optional[int]=None, attn_dropout:float=0., res_attention:bool=False):\n",
    "\n",
    "        d_context = ifnone(d_context, d_latent)\n",
    "        n_heads = ifnone(n_heads, 1)\n",
    "        d_head = ifnone(d_head, d_context//n_heads)\n",
    "\n",
    "        self.scale = d_head ** -0.5\n",
    "        self.n_heads, self.d_head, self.res_attention = n_heads, d_head, res_attention\n",
    "\n",
    "        self.to_q = nn.Linear(d_latent, d_head * n_heads, bias=False)\n",
    "        self.to_kv = nn.Linear(d_context, d_head * n_heads * 2, bias=False)\n",
    "\n",
    "        self.attn = ScaledDotProductAttention(d_k=d_head, res_attention=res_attention)\n",
    "\n",
    "        self.to_out = nn.Sequential(nn.Linear(d_head * n_heads, d_latent), nn.Dropout(attn_dropout))\n",
    "\n",
    "    def forward(self, x, context=None, mask=None):\n",
    "        h,d = self.n_heads, self.d_head\n",
    "        bs = x.shape[0]\n",
    "        q = self.to_q(x).view(bs, -1, h, d).transpose(1,2)\n",
    "        context = ifnone(context, x)\n",
    "        k, v = self.to_kv(context).chunk(2, dim=-1)\n",
    "        k = k.view(bs, -1, h, d).permute(0,2,3,1)\n",
    "        v = v.view(bs, -1, h, d).transpose(1,2)\n",
    "\n",
    "        if self.res_attention:\n",
    "            x, _, scores = self.attn(q, k, v)\n",
    "        else:\n",
    "            x, _ = self.attn(q, k, v)\n",
    "        x = x.permute(0, 2, 1, 3).reshape(bs, -1, h * d)\n",
    "\n",
    "        x = self.to_out(x)\n",
    "        if self.res_attention:\n",
    "            return x, scores\n",
    "        else: \n",
    "            return x\n",
    "\n",
    "\n",
    "class GEGLU(Module):\n",
    "    def forward(self, x):\n",
    "        x, gates = x.chunk(2, dim = -1)\n",
    "        return x * F.gelu(gates)\n",
    "\n",
    "\n",
    "class FeedForward(nn.Sequential):\n",
    "    def __init__(self, dim, mult=2, dropout=0.):\n",
    "        layers = [nn.Linear(dim, dim * mult), nn.GELU(), nn.Dropout(dropout), nn.Linear(dim * mult, dim)]\n",
    "        # layers = [nn.Linear(dim, dim * mult * 2), GEGLU(), nn.Dropout(dropout), nn.Linear(dim * mult, dim)]\n",
    "        super().__init__(*layers)\n",
    "\n",
    "\n",
    "class CrossAttention(Module):\n",
    "    def __init__(self, d_latent, d_context=None, n_heads=8, d_head=None, attn_dropout=0., fc_dropout=0.):\n",
    "        d_context = ifnone(d_context, d_latent)\n",
    "        self.norm_latent= nn.LayerNorm(d_latent)\n",
    "        self.norm_context = nn.LayerNorm(d_context) if d_context is not None else None\n",
    "        self.attn = Attention(d_latent, d_context=d_context, n_heads=n_heads, d_head=d_head, attn_dropout=attn_dropout)\n",
    "        self.norm_ff = nn.LayerNorm(d_latent)\n",
    "        self.ff = FeedForward(d_latent, dropout=fc_dropout)\n",
    "    \n",
    "    def forward(self, x, context=None, mask=None):\n",
    "        x = self.norm_latent(x)\n",
    "        if context is not None: \n",
    "            context = self.norm_context(context)\n",
    "        context = ifnone(context, x)\n",
    "        x = self.attn(x, context)\n",
    "        x = self.norm_ff(x)\n",
    "        x = self.ff(x)\n",
    "        return x\n",
    "\n",
    "\n",
    "class LatentTransformer(Module):\n",
    "    def __init__(self, d_latent, n_heads=8, d_head=None, attn_dropout=0., fc_dropout=0., self_per_cross_attn=1):\n",
    "        self.layers = nn.ModuleList()\n",
    "        for _ in range(self_per_cross_attn):\n",
    "            self.layers.append(nn.ModuleList([nn.LayerNorm(d_latent), \n",
    "                                              Attention(d_latent, n_heads=n_heads, d_head=d_head, attn_dropout=attn_dropout), \n",
    "                                              nn.LayerNorm(d_latent) , \n",
    "                                              FeedForward(d_latent, dropout=fc_dropout)\n",
    "                                              ]))\n",
    "\n",
    "    def forward(self, x, mask=None):\n",
    "        for attn_norm, att, ff_norm, ff in self.layers:\n",
    "            x = attn_norm(x)\n",
    "            x = att(x)\n",
    "            x = ff_norm(x)\n",
    "            x = ff(x)\n",
    "        return x        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class TSPerceiver(Module):\n",
    "    def __init__(self, c_in, c_out, seq_len, cat_szs=0, n_cont=0, n_latents=512, d_latent=128, d_context=None, n_layers=6, self_per_cross_attn=1, \n",
    "                 share_weights=True, cross_n_heads=1, self_n_heads=8, d_head=None, attn_dropout=0., fc_dropout=0., concat_pool=False):\n",
    "        \n",
    "        d_context = ifnone(d_context, d_latent)\n",
    "        \n",
    "        # Embedding\n",
    "        self.to_ts_emb = nn.Linear(c_in, d_context)\n",
    "        self.to_cat_emb = nn.ModuleList([nn.Embedding(s, d_context) for s in cat_szs]) if cat_szs else None\n",
    "        self.to_cont_emb = nn.ModuleList([nn.Linear(1, d_context) for i in range(n_cont)]) if n_cont else None\n",
    "\n",
    "        self.latent_array = nn.Parameter(torch.zeros(1, n_latents, d_context)) # N = q_len = indices = n_latents  \n",
    "\n",
    "        # Positional encoding\n",
    "        # self.ts_pos_enc = nn.Parameter(torch.zeros(1, 1, d_context))\n",
    "        # self.cat_pos_enc = nn.Parameter(torch.zeros(1, 1, d_context)) if cat_szs else None\n",
    "        # self.cont_pos_enc = nn.Parameter(torch.zeros(1, 1, d_context)) if n_cont else None \n",
    "        self.ts_pos_enc = nn.Parameter(torch.zeros(1, 1, 1))\n",
    "        self.cat_pos_enc = nn.Parameter(torch.zeros(1, 1, 1)) if cat_szs else None\n",
    "        self.cont_pos_enc = nn.Parameter(torch.zeros(1, 1, 1)) if n_cont else None \n",
    "        # self.pos_enc = nn.Parameter(torch.zeros(1, seq_len + (len(cat_szs) if cat_szs else 0) + n_cont, d_context))\n",
    "        pos_enc = torch.linspace(-1, 1, seq_len + (len(cat_szs) if cat_szs else 0) + n_cont).unsqueeze(0).unsqueeze(-1).repeat(1, 1, d_context)\n",
    "        self.pos_enc = nn.Parameter(pos_enc, requires_grad=False)\n",
    "\n",
    "        # Cross-attention & Latent-transformer\n",
    "        self.self_per_cross_attn = self_per_cross_attn\n",
    "        self.attn = nn.ModuleList()\n",
    "        for i in range(n_layers):\n",
    "            if i < 2 or not share_weights: \n",
    "                attn = [CrossAttention(d_latent, d_context=d_context, n_heads=cross_n_heads, d_head=d_head, attn_dropout=attn_dropout, \n",
    "                                       fc_dropout=fc_dropout)]\n",
    "                if self_per_cross_attn != 0:\n",
    "                    attn += [LatentTransformer(d_latent, n_heads=self_n_heads, d_head=d_head, attn_dropout=attn_dropout, fc_dropout=fc_dropout, \n",
    "                                               self_per_cross_attn=self_per_cross_attn)]\n",
    "            self.attn.append(nn.ModuleList(attn))\n",
    "\n",
    "        self.head = nn.Sequential(GACP1d() if concat_pool else GAP1d(), nn.BatchNorm1d(d_latent*(1+concat_pool)), nn.Linear(d_latent*(1+concat_pool), c_out))\n",
    "\n",
    "    def forward(self, x):\n",
    "        # Embedding\n",
    "        # Time series\n",
    "        if isinstance(x, tuple):\n",
    "            x_ts, (x_cat, x_cont) = x\n",
    "        else: \n",
    "            x_ts, x_cat, x_cont = x, None, None\n",
    "        context = self.to_ts_emb(x_ts.transpose(1,2))\n",
    "        context += self.ts_pos_enc\n",
    "        # Categorical\n",
    "        if self.to_cat_emb is not None: \n",
    "            x_cat = torch.cat([e(x_cat[:,i]).unsqueeze(1) for i,e in enumerate(self.to_cat_emb)], 1)\n",
    "            x_cat += self.cat_pos_enc\n",
    "            context = torch.cat([context, x_cat], 1)\n",
    "        # Continuous\n",
    "        if self.to_cont_emb is not None:\n",
    "            x_cont = torch.cat([e(x_cont[:,i].unsqueeze(1).unsqueeze(2)) for i,e in enumerate(self.to_cont_emb)], 1)\n",
    "            x_cont += self.cont_pos_enc\n",
    "            context = torch.cat([context, x_cont], 1)\n",
    "        context += self.pos_enc\n",
    "        \n",
    "        # Latent array\n",
    "        x = self.latent_array.repeat(context.shape[0], 1, 1)\n",
    "        \n",
    "        # Cross-attention & Latent transformer\n",
    "        for i, attn in enumerate(self.attn):\n",
    "            x = attn[0](x, context=context) + x # cross-attention\n",
    "            if self.self_per_cross_attn != 0:\n",
    "                x = attn[1](x) + x              # latent transformer\n",
    "\n",
    "        x = x.transpose(1,2)\n",
    "        \n",
    "        #Head\n",
    "        out = self.head(x)\n",
    "        return out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tsai.basics import *\n",
    "from tsai.data.all import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Feature Extraction: 100%|██████████████████████████████████████████| 30/30 [00:00<00:00, 189.16it/s]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(60, 11)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#|extras\n",
    "dsid = 'OliveOil'\n",
    "X, y, splits = get_UCR_data(dsid, split_data=False)\n",
    "ts_features_df = get_ts_features(X, y)\n",
    "ts_features_df.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|extras\n",
    "# raw ts\n",
    "tfms  = [None, [Categorize()]]\n",
    "batch_tfms = TSStandardize(by_sample=True)\n",
    "ts_dls = get_ts_dls(X, y, splits=splits, tfms=tfms, batch_tfms=batch_tfms)\n",
    "\n",
    "# ts features\n",
    "cat_names = None\n",
    "cont_names = ts_features_df.columns[:-2]\n",
    "y_names = 'target'\n",
    "tab_dls = get_tabular_dls(ts_features_df, cat_names=cat_names, cont_names=cont_names, y_names=y_names, splits=splits)\n",
    "\n",
    "# mixed\n",
    "mixed_dls = get_mixed_dls(ts_dls, tab_dls)\n",
    "xb, yb = mixed_dls.one_batch()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|extras\n",
    "model = TSPerceiver(ts_dls.vars, ts_dls.c, ts_dls.len, cat_szs=0, \n",
    "                    # n_cont=0, \n",
    "                    n_cont=xb[1][1].shape[1], \n",
    "                    n_latents=128, d_latent=128, n_layers=3, self_per_cross_attn=1, share_weights=True,\n",
    "                    cross_n_heads=16, self_n_heads=16, d_head=None, attn_dropout=0., fc_dropout=0.).to(device)\n",
    "test_eq(model(xb).shape, (yb.shape[0], len(np.unique(y))))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/javascript": "IPython.notebook.save_checkpoint();",
      "text/plain": [
       "<IPython.core.display.Javascript object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/Users/nacho/notebooks/tsai/nbs/123_models.TSPerceiver.ipynb saved at 2022-11-09 13:11:59\n",
      "Correct notebook to script conversion! 😃\n",
      "Wednesday 09/11/22 13:12:02 CET\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "                <audio  controls=\"controls\" autoplay=\"autoplay\">\n",
       "                    <source src=\"data:audio/wav;base64,UklGRvQHAABXQVZFZm10IBAAAAABAAEAECcAACBOAAACABAAZGF0YdAHAAAAAPF/iPh/gOoOon6w6ayCoR2ZeyfbjobxK+F2Hs0XjKc5i3DGvzaTlEaraE+zz5uLUl9f46fHpWJdxVSrnfmw8mYEScqUP70cb0Q8X41uysJ1si6Eh1jYzXp9IE2DzOYsftYRyoCY9dJ/8QICgIcEun8D9PmAaBPlfT7lq4MFIlh61tYPiCswIHX+yBaOqT1QbuW7qpVQSv9lu6+xnvRVSlyopAypbGBTUdSalrSTaUBFYpInwUpxOzhti5TOdndyKhCGrdwAfBUcXIJB69p+Vw1egB76+n9q/h6ADglbf4LvnIHfF/981ODThF4m8HiS0riJVjQ6c+/EOZCYQfJrGrhBmPVNMmNArLKhQlkXWYqhbaxXY8ZNHphLuBJsZUEckCTFVHMgNKGJytIDeSUmw4QN4Qx9pReTgb3vYX/TCBuApf75f+P5Y4CRDdN+B+tngk8c8nt03CKGqipgd13OhotwOC5x9MCAknFFcmlmtPmagFFFYOCo0qRzXMhVi57pryNmIEqJlRi8bm52PfuNM8k4dfQv+4cO12l6zCGdg3jl730uE/KAPvS+f0wEAoAsA89/XfXQgBESIn6S5luDtiC8eh/YmIfpLqt1OMp5jXg8/24MveqUNUnPZsqw0Z3yVDldnaUOqIZfXlKrm36zzWhjRhaT+r+ncHI5/otUzfd2uSt7hl/bqXtoHaCC6+mqfrAOeoDD+PJ/xf8RgLMHfH/b8GeBihZIfSXidoQSJWB52NM1iRkzz3MkxpKPbUCrbDu5d5fgTAxkSK3JoEhYD1p2omere2LZTuqYLbdWa49Cx5Dww7tyXDUnioXRkHhwJyKFvd/AfPoYy4Fl7j1/LQorgEr9/X89+0qAOAwAf13sJoL8Gkd8wt25hWIp3Heez/eKODfPcSPCzpFNRDVqf7UlmnNQKGHgqd+jgVvJVm2f265QZTpLS5byur1tpT6ajvrHq3Q2MXWIxtUCehoj8YMk5LB9hRQegeTypn+nBQWA0QHgf7f2q4C5EFt+5ucOg2YfHXtq2SSHpS0ydnTL4IxFO6pvNb4ulBdInWfcsfSc7VMmXpSmE6eeXmZThJxpsgRohEfOk86+AHCoOpOMFsx1dv8s6oYT2k17uR7ngpXod34IEJqAaPfnfyABCIBZBpl/NPI2gTQVjX134x2ExSPMeR7VtYjZMWJ0W8ftjkA/YW1durCWykvjZFKu4p9LVwVbZKNkqpxh6U+6mRC2mGq2Q3SRvsIgcpc2sIpD0Bp4uiiFhW3ecXxOGgaCDe0Vf4cLPoDv+/5/mfw1gN4KKX+17emBqBmYfBHfVYUZKFR44NBtiv41bHJUwx+RJkP1apu2VJlkTwli4qrwoo1ax1dToNCtemRSTBGXz7kJbdM/PY/Dxht0dTLziH7Ul3loJEiE0uJsfdsVTYGL8Yt/AgcMgHYA7X8S+IqAYA+QfjzpxIIVHnp7tdqzhmAstXaxzEqMETpScGC/dJP3Rmdo8LIZnOVSEF+Opxumsl1sVF+dVrE5Z6NIiZSkvVdv2zsqjdnK8HVDLlyHyNjuegogM4NA5z9+YRG9gA722H97AgOA/gSyf43zCIHdE899yuTIg3ciNXpm1jmImTDwdJPITI4RPhRugbvslbFKt2Vfr/6eTFb4W1WkY6m6YPdQjJr2tNZp3EQlko7BgXHRNz2LAc+gdwMq7IUf3R58ohtFgrbr6n7hDFWAlPr8f/T9I4CECU9/De+vgVQY5nxh4POEzybJeCTS5YnCNAZzhsRzkP1Bsmu4t4aYU07nYuerA6KWWcJYO6HHrKJjaE3Zl624UWz/QOOPjcWHc7QzdIk40yl5tCWjhIDhJX0xF4CBMvBsf10IF4Ac//Z/bPlsgAcOwn6S6n6CwxzUewLcRoYaKzV38M23i9o493CNwL6S1UUuaQe0QpvbUfdfiqglpcRccFU+nkWwambASUiVfLyqbg49xY2eyWh1hy/Sh37XjHpaIYKD7OUEfrgS5IC09MV/1gMBgKMDyH/n9N6AhhINfh7mdoMoIZt6r9fAh1cvfHXNya6N4DzDbqi8K5WWSYlmbbAdnkpV6FxJpWSo1V8DUmGb3rMRaQBG2JJgwN9wCDnNi8HNI3dKK1aG0dvHe/UciIJf6rt+Og5wgDn59X9P/xWAKQhxf2XweYH+FjB9suGVhIMlOnlo02GJhTOdc7vFyo/TQGxs2Li7lz9NwmPurBihnVi7WSWiwKvGYntOpJiOt5drKUKMkFnE8HLxNPmJ9NG4eP8mAYUv4Np8hhi3gdruSX+3CSWAwP38f8f6UoCuDPF+6Os8gnAbKnxQ3d2F0imydzDPKIuiN5lxu8EKkrFE82kftW2az1DbYImpMqTUW3FWIJ83r5hl2koJlla7+m0+PmSOZcjcdMgwS4g11iZ6qCLUg5jkxn0QFA6BWvOvfzEFBIBHAtp/Qfa3gC4RSH5y5yeD2B/8evnYS4cULgR2CMsUja47cG/QvW6UeEhXZ3+xP51GVNVdP6Zpp+1eDFM5nMeySWghR4+TNL85cD46YIyCzKJ2kCzEhoTabXtGHs+CCemJfpMPjoDe9+t/qQALgM8Gj3++8UaBqRV2fQTjO4Q3JKd5r9TgiEYyMHTxxiWPpz8jbfq585YpTJpk960xoKFXsVoTo7yq6GGMTw==\" type=\"audio/wav\" />\n",
       "                    Your browser does not support the audio element.\n",
       "                </audio>\n",
       "              "
      ],
      "text/plain": [
       "<IPython.lib.display.Audio object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "#|eval: false\n",
    "#|hide\n",
    "from tsai.export import get_nb_name; nb_name = get_nb_name(locals())\n",
    "from tsai.imports import create_scripts; create_scripts(nb_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "python3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
